import torch
import cv2
from matplotlib import pyplot as plt
from torch.nn import functional as F

# Small demo of a max-reduction over dim 1 of a random 4-D tensor.
a = torch.randn(2, 5, 5, 5)

# Tensor.max(dim=...) yields a (values, indices) pair; keepdim=True retains the
# reduced dimension with size 1, so both results have shape (2, 1, 5, 5).
lg, ind = a.max(dim=1, keepdim=True)

# Same output as two consecutive print() calls: values, newline, indices, newline.
print(lg, ind, sep="\n")


# Load an image, blur it with repeated 3x3 mean filters, and display the result.
IMAGE_PATH = 'F:/1219/JPEGImages/5e09de276492416799c81cd67ad3d497_315_1_1682638265528_14.jpg'
BLUR_PASSES = 10

# cv2.imread does not raise on a missing/unreadable file; it returns None,
# which would otherwise fail later with an obscure error inside cvtColor.
ig = cv2.imread(IMAGE_PATH)
if ig is None:
    raise FileNotFoundError(f"could not read image: {IMAGE_PATH}")

# OpenCV decodes channels as BGR; matplotlib's imshow expects RGB.
ig = cv2.cvtColor(ig, cv2.COLOR_BGR2RGB)
# uint8 [0, 255] -> float [0, 1] (numpy division by a float yields float64).
ig = ig / 255.0

# (H, W, C) numpy array -> (1, C, H, W) tensor, the layout avg_pool2d expects.
ig = torch.from_numpy(ig).unsqueeze(0).permute(0, 3, 1, 2)

# Repeated 3x3 mean filter; stride=1 and padding=1 keep the spatial size.
# count_include_pad=False: border windows average only real pixels instead of
# mixing in the zero padding, which would darken the edges on every pass.
for _ in range(BLUR_PASSES):
    ig = F.avg_pool2d(ig, kernel_size=3, padding=1, stride=1, count_include_pad=False)

# (1, C, H, W) tensor -> (H, W, C) numpy array for display.
ig = ig.permute(0, 2, 3, 1).numpy()[0]
plt.imshow(ig)
plt.show()